"""
Main script to start experiments.
Takes a flag --env-type (see below for choices) and loads the parameters from the respective config file.
"""
import argparse
import warnings

import numpy as np
import torch
import wandb

# get configs

from config import args_easy_treasure_varibad, args_medi_treasure_varibad, args_hard_treasure_varibad
from config import args_easy_treasure_rl2, args_medi_treasure_rl2, args_hard_treasure_rl2
from config import args_bandit_rl2, args_bandit_varibad
from config import args_easy_world_rl2

from config.gridworld import \
    args_grid_belief_oracle, args_grid_rl2, args_grid_varibad
from config.pointrobot import \
    args_pointrobot_multitask, args_pointrobot_varibad, args_pointrobot_rl2, args_pointrobot_humplik
from config.mujoco import \
    args_cheetah_dir_multitask, args_cheetah_dir_expert, args_cheetah_dir_rl2, args_cheetah_dir_varibad, \
    args_cheetah_vel_multitask, args_cheetah_vel_expert, args_cheetah_vel_rl2, args_cheetah_vel_varibad, \
    args_cheetah_vel_avg, \
    args_ant_dir_multitask, args_ant_dir_expert, args_ant_dir_rl2, args_ant_dir_varibad, \
    args_ant_goal_multitask, args_ant_goal_expert, args_ant_goal_rl2, args_ant_goal_varibad, \
    args_ant_goal_humplik, \
    args_walker_multitask, args_walker_expert, args_walker_avg, args_walker_rl2, args_walker_varibad, \
    args_humanoid_dir_varibad, args_humanoid_dir_rl2, args_humanoid_dir_multitask, args_humanoid_dir_expert
from environments.parallel_envs import make_vec_envs
from learner import Learner
from metalearner import MetaLearner


from prep_args import C, set_run_parameters

# Hyper-parameter flags for a run, listed in the order in which `main` forwards
# them to the config parser. Every value is wrapped in `C` (from `prep_args`);
# `main` passes this dict to `set_run_parameters` together with `--n` and then
# stringifies each value into a command-line style argument.
# NOTE(review): presumably each `C` holds the candidate values of a grid search
# and `set_run_parameters` selects one per key -- confirm in `prep_args`.
hyper_parameters = {
    '--ppo_num_minibatch': C((2, 4)),
    '--ppo_clip_param': C((0.05, 0.1)),
    '--policy_num_steps': C((100, 200)),
    '--vae_batch_num_trajs': C((10, 25)),
    '--num_vae_updates': C((1, 3)),
    '--kl_weight': C((0.1, 1)),
    '--num_frames': C((20_000_000,)),
}

def main():
    """Parse the command line, build the run configuration and train once per seed.

    Recognised flags:
        --env-type: one of ``bandit_rl2``, ``mean_bandit_rl2``,
            ``control_bandit_rl2``, ``bandit_vb``, ``mean_bandit_vb`` or
            ``control_bandit_vb``.
        --wandb_id: optional suffix for the Weights & Biases run id. When given,
            the run id is ``<seed><wandb_id>``; when omitted, wandb picks a
            fresh id.
        --n: integer passed to ``set_run_parameters`` and used in the
            experiment label ``grid<n>``.

    Any other command-line argument is parsed but not forwarded to the config
    parser; the config is built solely from ``hyper_parameters`` and the
    experiment label.

    Raises:
        ValueError: if ``--env-type`` is not one of the supported values, or if
            action normalisation is enabled but the environment's action bounds
            are not exactly ``[-1, 1]``.
        RuntimeError: if deterministic execution is requested with more than
            one process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default=None)
    parser.add_argument('--wandb_id', type=str, default=None)
    parser.add_argument('--n', type=int, default=None)
    # NOTE(review): unknown arguments are deliberately tolerated here but are not
    # passed on to `get_args` -- confirm that dropping them is intended.
    cli_args, _ = parser.parse_known_args()
    env = cli_args.env_type
    n = cli_args.n
    wandb_id = cli_args.wandb_id

    # Let `set_run_parameters` update the module-level grid for this `n`, then
    # flatten it into an argv-style list: [flag, value, flag, value, ...].
    set_run_parameters(hyper_parameters, n)
    h_args = []
    for flag, value in hyper_parameters.items():
        h_args += [flag, str(value)]
    h_args += ["--exp_label", f"grid{n}"]

    # env-type -> (config module, env_name override).
    # An override of None keeps whatever env_name the config module sets.
    env_table = {
        'bandit_rl2': (args_bandit_rl2, "StochasticBandit-v0"),
        'mean_bandit_rl2': (args_bandit_rl2, "StochasticMeanBandit-v0"),
        'control_bandit_rl2': (args_bandit_rl2, "StochasticControlBandit-v0"),
        'bandit_vb': (args_bandit_varibad, None),
        'mean_bandit_vb': (args_bandit_varibad, "StochasticMeanBandit-v0"),
        'control_bandit_vb': (args_bandit_varibad, "StochasticControlBandit-v0"),
    }
    if env not in env_table:
        raise ValueError(
            f"Invalid Environment: {env!r} (expected one of {sorted(env_table)})")
    config_module, env_name = env_table[env]
    args = config_module.get_args(h_args)
    if env_name is not None:
        args.env_name = env_name

    # warning for deterministic execution
    if args.deterministic_execution:
        print('Envoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError('If you want fully deterministic code, run it with num_processes=1. '
                               'Warning: This will slow things down and might break A2C if '
                               'policy_num_steps < env._max_episode_steps.')

    # if we're normalising the actions, we have to make sure that the env expects actions within [-1, 1]
    if args.norm_actions_pre_sampling or args.norm_actions_post_sampling:
        envs = make_vec_envs(env_name=args.env_name, seed=0, num_processes=args.num_processes,
                             gamma=args.policy_gamma, device='cpu',
                             episodes_per_task=args.max_rollouts_per_task,
                             normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                             tasks=None,
                             )
        try:
            low = np.asarray(envs.action_space.low)
            high = np.asarray(envs.action_space.high)
            # Explicit check instead of `assert`: it still runs under `python -O`
            # and does not trip over an ambiguous array truth value when the
            # bounds contain several distinct values.
            if not (np.all(low == -1) and np.all(high == 1)):
                raise ValueError(
                    'Action normalisation requires action bounds of exactly [-1, 1], '
                    f'but {args.env_name} has low={low} and high={high}.')
        finally:
            # These envs exist only for the bounds check; release them before
            # training creates its own. Guarded because the object returned by
            # `make_vec_envs` is defined elsewhere -- presumably it exposes
            # `close()` like other vectorised envs.
            close_envs = getattr(envs, 'close', None)
            if callable(close_envs):
                close_envs()

    # clean up arguments
    if args.disable_metalearner or args.disable_decoder:
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False

    if hasattr(args, 'decode_only_past') and args.decode_only_past:
        args.split_batches_by_elbo = True
    # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes:
    #     args.split_batches_by_elbo = True

    # begin training (loop through all passed seeds)
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        print('training', seed)
        args.seed = seed
        args.action_space = None

        # Log the full argument namespace plus the requested env type.
        config = {**vars(args), 'env_type': env}
        # Only pin the run id when the caller supplied one. Otherwise every
        # run of a given seed would share the id "<seed>None" and, with
        # resume="allow", be merged into the same wandb run.
        run_id = f"{seed}{wandb_id}" if wandb_id is not None else None
        wandb.init(project='VBad', id=run_id, resume="allow",
                   sync_tensorboard=True, config=config)
        if args.disable_metalearner:
            # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`.
            # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
            learner = Learner(args)
        else:
            learner = MetaLearner(args)
        learner.train()
        wandb.finish()


# Start the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
